{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Offline_ASR.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "toc_visible": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_wIWPxBVc3_O"
      },
   "source": [
    "# NeMo offline ASR\n",
    "\n",
    "This notebook demonstrates how to  \n",
    "\n",
    "* transcribe an audio file (offline ASR) with greedy decoder\n",
    "* extract timestamps information from the model to split audio into separate words\n",
    "* use beam search decoder with N-gram language model re-scoring\n",
    "\n",
    "You may find more info on how to train and use language models for ASR models here:\n",
    "https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/asr_language_modeling.html\n"
   ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gzcsqceVdtj3"
      },
      "source": [
        "## Installation\n",
        "NeMo can be installed via simple pip command. \n",
        "\n",
        "Optional CTC beam search decoder might require restart of Colab runtime after installation."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "I9eIxAyKHREB"
      },
      "source": [
        "BRANCH = 'v1.0.0'\n",
        "try:\n",
        "    # Import NeMo Speech Recognition collection\n",
        "    import nemo.collections.asr as nemo_asr\n",
        "except ModuleNotFoundError:\n",
        "    !python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[all]\n",
        "\n",
        "# check if we have optional Plotly for visualization\n",
        "try:\n",
        "    from plotly import graph_objects as go\n",
        "except ModuleNotFoundError:\n",
        "    !pip install plotly\n",
        "    from plotly import graph_objects as go\n",
        "\n",
	"# check if we have optional ipywidgets for tqdm/notebook\n",
	"try:\n",
        "    import ipywidgets\n",
        "except ModuleNotFoundError:\n",
        "    !pip install ipywidgets\n",
        "\n",
        "# check if CTC beam decoders are installed\n",
        "try:\n",
        "    import ctc_decoders\n",
        "except ModuleNotFoundError:\n",
        "    # install beam search decoder\n",
        "    !apt-get install -y swig\n",
        "    !git clone https://github.com/NVIDIA/NeMo -b \"$BRANCH\"\n",
        "    !cd NeMo && bash scripts/asr_language_modeling/ngram_lm/install_beamsearch_decoders.sh\n",
        "    print('Restarting Colab runtime to successfully import built module.')\n",
        "    print('Please re-run the notebook.')\n",
        "    import os\n",
        "    os.kill(os.getpid(), 9)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-X2OyAxreGfl"
      },
      "source": [
        "import numpy as np\n",
        "# Import audio processing library\n",
        "import librosa\n",
        "# We'll use this to listen to audio\n",
        "from IPython.display import Audio, display"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zodyzdyTVXas"
      },
      "source": [
        "## Instantiate pre-trained NeMo model\n",
        "``from_pretrained(...)`` API downloads and initializes model directly from the cloud. \n",
        "\n",
        "Alternatively, ``restore_from(...)`` allows loading a model from a disk.\n",
        "\n",
        "To display available pre-trained models from the cloud, please use ``list_available_models()`` method."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "f_J9cuU1H6Bn"
      },
      "source": [
        "nemo_asr.models.EncDecCTCModel.list_available_models()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "x2LMVI9qqtEV"
      },
      "source": [
        "Let's load a base English QuartzNet15x5 model."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ZhWmR7lbvwSm"
      },
      "source": [
        "asr_model = nemo_asr.models.EncDecCTCModel.from_pretrained(model_name='QuartzNet15x5Base-En', strict=False)"
      ],
      "execution_count": null,
      "outputs": []
    },
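    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A downloaded model can also be serialized to disk with ``save_to(...)`` and loaded back with ``restore_from(...)``. A minimal sketch (the ``quartznet.nemo`` filename is just an example):"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# save the model to a .nemo archive on disk (the filename is arbitrary)\n",
        "asr_model.save_to('quartznet.nemo')\n",
        "# load it back from disk, e.g. for offline use\n",
        "asr_model = nemo_asr.models.EncDecCTCModel.restore_from('quartznet.nemo')"
      ],
      "execution_count": null,
      "outputs": []
    },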
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HESTZmIzzCEj"
      },
      "source": [
        "## Get test audio clip"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QPWn89l-zLXo"
      },
      "source": [
        "Let's download and analyze a test audio signal."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "02gDfK7czSVV"
      },
      "source": [
        "# Download audio sample which we'll try\n",
        "# This is a sample from LibriSpeech dev clean subset - the model hasn't seen it before\n",
        "AUDIO_FILENAME = '1919-142785-0028.wav'\n",
        "!wget https://dldata-public.s3.us-east-2.amazonaws.com/1919-142785-0028.wav\n",
        "\n",
        "# load audio signal with librosa\n",
        "signal, sample_rate = librosa.load(AUDIO_FILENAME, sr=None)\n",
        "\n",
        "# display audio player for the signal\n",
        "display(Audio(data=signal, rate=sample_rate))\n",
        "\n",
        "# plot the signal in time domain\n",
        "fig_signal = go.Figure(\n",
        "    go.Scatter(x=np.arange(signal.shape[0])/sample_rate,\n",
        "               y=signal, line={'color': 'green'},\n",
        "               name='Waveform',\n",
        "               hovertemplate='Time: %{x:.2f} s<br>Amplitude: %{y:.2f}<br><extra></extra>'),\n",
        "    layout={\n",
        "        'height': 300,\n",
        "        'xaxis': {'title': 'Time, s'},\n",
        "        'yaxis': {'title': 'Amplitude'},\n",
        "        'title': 'Audio Signal',\n",
        "        'margin': dict(l=0, r=0, t=40, b=0, pad=0),\n",
        "    }\n",
        ")\n",
        "fig_signal.show()\n",
        "\n",
        "# calculate amplitude spectrum\n",
        "time_stride=0.01\n",
        "hop_length = int(sample_rate*time_stride)\n",
        "n_fft = 512\n",
        "# linear scale spectrogram\n",
        "s = librosa.stft(y=signal,\n",
        "                 n_fft=n_fft,\n",
        "                 hop_length=hop_length)\n",
        "s_db = librosa.power_to_db(np.abs(s)**2, ref=np.max, top_db=100)\n",
        "\n",
        "# plot the signal in frequency domain\n",
        "fig_spectrum = go.Figure(\n",
        "    go.Heatmap(z=s_db,\n",
        "               colorscale=[\n",
        "                   [0, 'rgb(30,62,62)'],\n",
        "                   [0.5, 'rgb(30,128,128)'],\n",
        "                   [1, 'rgb(30,255,30)'],\n",
        "               ],\n",
        "               colorbar=dict(\n",
        "                   ticksuffix=' dB'\n",
        "               ),\n",
        "               dx=time_stride, dy=sample_rate/n_fft/1000,\n",
        "               name='Spectrogram',\n",
        "               hovertemplate='Time: %{x:.2f} s<br>Frequency: %{y:.2f} kHz<br>Magnitude: %{z:.2f} dB<extra></extra>'),\n",
        "    layout={\n",
        "        'height': 300,\n",
        "        'xaxis': {'title': 'Time, s'},\n",
        "        'yaxis': {'title': 'Frequency, kHz'},\n",
        "        'title': 'Spectrogram',\n",
        "        'margin': dict(l=0, r=0, t=40, b=0, pad=0),\n",
        "    }\n",
        ")\n",
        "fig_spectrum.show()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jQSj-IhEhrtI"
      },
      "source": [
        "## Offline inference\n",
        "If we have an entire audio clip available, then we can do offline inference with a pre-trained model to transcribe it.\n",
        "\n",
        "The easiest way to do it is to call ASR model's ``transcribe(...)`` method  that allows transcribing multiple files in a batch."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "s0ERrXIzKpwu"
      },
      "source": [
        "# Convert our audio sample to text\n",
        "files = [AUDIO_FILENAME]\n",
        "transcript = asr_model.transcribe(paths2audio_files=files)[0]\n",
        "print(f'Transcript: \"{transcript}\"')"
      ],
      "execution_count": null,
      "outputs": []
    },
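    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "If a ground-truth transcript is available, the hypothesis can be scored with NeMo's ``word_error_rate`` helper. A minimal sketch, where the reference string below is a placeholder you should replace with the actual transcript:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "from nemo.collections.asr.metrics.wer import word_error_rate\n",
        "\n",
        "# placeholder: replace with the actual ground-truth transcript of the clip\n",
        "reference = 'put the ground truth transcript here'\n",
        "wer = word_error_rate(hypotheses=[transcript], references=[reference])\n",
        "print(f'WER: {wer:.2%}')"
      ],
      "execution_count": null,
      "outputs": []
    },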
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_UOoj-WfQoL_"
      },
      "source": [
        "## Extract timestamps and split words\n",
        "``transcribe()`` generates a text applying a CTC greedy decoder to raw probabilities distribution over alphabet's characters from ASR model. We can get those raw probabilities with ``logprobs=True`` argument."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-0Sk0C9-LmAR"
      },
      "source": [
        "# softmax implementation in NumPy\n",
        "def softmax(logits):\n",
        "    e = np.exp(logits - np.max(logits))\n",
        "    return e / e.sum(axis=-1).reshape([logits.shape[0], 1])\n",
        "\n",
        "# let's do inference once again but without decoder\n",
        "logits = asr_model.transcribe(files, logprobs=True)[0]\n",
        "probs = softmax(logits)\n",
        "\n",
        "# 20ms is duration of a timestep at output of the model\n",
        "time_stride = 0.02\n",
        "\n",
        "# get model's alphabet\n",
        "labels = list(asr_model.decoder.vocabulary) + ['blank']\n",
        "labels[0] = 'space'\n",
        "\n",
        "# plot probability distribution over characters for each timestep\n",
        "fig_probs = go.Figure(\n",
        "    go.Heatmap(z=probs.transpose(),\n",
        "               colorscale=[\n",
        "                   [0, 'rgb(30,62,62)'],\n",
        "                   [1, 'rgb(30,255,30)'],\n",
        "               ],\n",
        "               y=labels,\n",
        "               dx=time_stride,\n",
        "               name='Probs',\n",
        "               hovertemplate='Time: %{x:.2f} s<br>Character: %{y}<br>Probability: %{z:.2f}<extra></extra>'),\n",
        "    layout={\n",
        "        'height': 300,\n",
        "        'xaxis': {'title': 'Time, s'},\n",
        "        'yaxis': {'title': 'Characters'},\n",
        "        'title': 'Character Probabilities',\n",
        "        'margin': dict(l=0, r=0, t=40, b=0, pad=0),\n",
        "    }\n",
        ")\n",
        "fig_probs.show()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YiNMZBodIaSP"
      },
      "source": [
        "It is easy to identify timesteps for space character."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "32aaW3HEJ89l"
      },
      "source": [
        "# get timestamps for space symbols\n",
        "spaces = []\n",
        "\n",
        "state = ''\n",
        "idx_state = 0\n",
        "\n",
        "if np.argmax(probs[0]) == 0:\n",
        "    state = 'space'\n",
        "\n",
        "for idx in range(1, probs.shape[0]):\n",
        "    current_char_idx = np.argmax(probs[idx])\n",
        "    if state == 'space' and current_char_idx != 0 and current_char_idx != 28:\n",
        "        spaces.append([idx_state, idx-1])\n",
        "        state = ''\n",
        "    if state == '':\n",
        "        if current_char_idx == 0:\n",
        "            state = 'space'\n",
        "            idx_state = idx\n",
        "\n",
        "if state == 'space':\n",
        "    spaces.append([idx_state, len(pred)-1])"
      ],
      "execution_count": null,
      "outputs": []
    },
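    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a quick sanity check, the number of detected space segments should be one fewer than the number of words in the greedy transcript:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# N words should be separated by N-1 spaces\n",
        "n_words = len(transcript.split())\n",
        "print(f'Detected {len(spaces)} spaces for {n_words} words.')"
      ],
      "execution_count": null,
      "outputs": []
    },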
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rqg4oxpsL8cW"
      },
      "source": [
        "Then we can split original audio signal into separate words. It is worth to mention that all timestamps have a delay (or an offset) depending on the model. We need to take it into account for alignment."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "a-LSg9dSL_O1"
      },
      "source": [
        "# calibration offset for timestamps: 180 ms\n",
        "offset = -0.18\n",
        "\n",
        "# split the transcript into words\n",
        "words = transcript.split()\n",
        "\n",
        "# cut words\n",
        "pos_prev = 0\n",
        "for j, spot in enumerate(spaces):\n",
        "    display(words[j])\n",
        "    pos_end = offset + (spot[0]+spot[1])/2*time_stride\n",
        "    display(Audio(signal[int(pos_prev*sample_rate):int(pos_end*sample_rate)],\n",
        "                 rate=sample_rate))\n",
        "    pos_prev = pos_end\n",
        "\n",
        "display(words[j+1])\n",
        "display(Audio(signal[int(pos_prev*sample_rate):],\n",
        "        rate=sample_rate))"
      ],
      "execution_count": null,
      "outputs": []
    },
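    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To keep the segments around, each word can also be written out as a separate wav file. A sketch assuming the ``soundfile`` package is available (``pip install soundfile`` otherwise); the filenames are arbitrary:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "import soundfile as sf\n",
        "\n",
        "# write one wav file per word, using the same boundaries as above\n",
        "pos_prev = 0\n",
        "for j, spot in enumerate(spaces):\n",
        "    pos_end = offset + (spot[0]+spot[1])/2*time_stride\n",
        "    sf.write(f'word_{j:02d}_{words[j]}.wav',\n",
        "             signal[int(pos_prev*sample_rate):int(pos_end*sample_rate)],\n",
        "             sample_rate)\n",
        "    pos_prev = pos_end\n",
        "\n",
        "# the last word runs to the end of the signal\n",
        "sf.write(f'word_{len(spaces):02d}_{words[-1]}.wav',\n",
        "         signal[int(pos_prev*sample_rate):],\n",
        "         sample_rate)"
      ],
      "execution_count": null,
      "outputs": []
    },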
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Q8Jvwe4Ahncx"
      },
      "source": [
        "## Offline inference with beam search decoder and N-gram language model re-scoring\n",
        "\n",
        "It is possible to use an external [KenLM](https://kheafield.com/code/kenlm/)-based N-gram language model to rescore multiple transcription candidates. \n",
        "\n",
        "Let's download and preprocess LibriSpeech 3-gram language model."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "EIh8wTVs5uH7"
      },
      "source": [
        "import gzip\n",
        "import os, shutil, wget\n",
        "\n",
        "lm_gzip_path = '3-gram.pruned.1e-7.arpa.gz'\n",
        "if not os.path.exists(lm_gzip_path):\n",
        "    print('Downloading pruned 3-gram model.')\n",
        "    lm_url = 'http://www.openslr.org/resources/11/3-gram.pruned.1e-7.arpa.gz'\n",
        "    lm_gzip_path = wget.download(lm_url)\n",
        "    print('Downloaded the 3-gram language model.')\n",
        "else:\n",
        "    print('Pruned .arpa.gz already exists.')\n",
        "\n",
        "uppercase_lm_path = '3-gram.pruned.1e-7.arpa'\n",
        "if not os.path.exists(uppercase_lm_path):\n",
        "    with gzip.open(lm_gzip_path, 'rb') as f_zipped:\n",
        "        with open(uppercase_lm_path, 'wb') as f_unzipped:\n",
        "            shutil.copyfileobj(f_zipped, f_unzipped)\n",
        "    print('Unzipped the 3-gram language model.')\n",
        "else:\n",
        "    print('Unzipped .arpa already exists.')\n",
        "\n",
        "lm_path = 'lowercase_3-gram.pruned.1e-7.arpa'\n",
        "if not os.path.exists(lm_path):\n",
        "    with open(uppercase_lm_path, 'r') as f_upper:\n",
        "        with open(lm_path, 'w') as f_lower:\n",
        "            for line in f_upper:\n",
        "                f_lower.write(line.lower())\n",
        "print('Converted language model file to lowercase.')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fLDbUkzzUAqW"
      },
      "source": [
        "Let's instantiate ``BeamSearchDecoderWithLM`` module."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_qgKa9L954bJ"
      },
      "source": [
        "beam_search_lm = nemo_asr.modules.BeamSearchDecoderWithLM(\n",
        "    vocab=list(asr_model.decoder.vocabulary),\n",
        "    beam_width=16,\n",
        "    alpha=2, beta=1.5,\n",
        "    lm_path=lm_path,\n",
        "    num_cpus=max(os.cpu_count(), 1),\n",
        "    input_tensor=False)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NSH8EvL7USac"
      },
      "source": [
        "Now we can check all transcription candidates along with their scores."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "nV1CAy0Dit-g"
      },
      "source": [
        "beam_search_lm.forward(log_probs = np.expand_dims(probs, axis=0), log_probs_length=None)"
      ],
      "execution_count": null,
      "outputs": []
    },
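    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "``forward()`` here returns, for each batch item, a list of (score, candidate transcript) pairs. A minimal sketch of picking the highest-scoring candidate, assuming this return structure:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# forward() returns one list of (score, transcript) pairs per batch item\n",
        "beams = beam_search_lm.forward(log_probs=np.expand_dims(probs, axis=0), log_probs_length=None)\n",
        "# pick the candidate with the highest score\n",
        "best_score, best_transcript = max(beams[0], key=lambda b: b[0])\n",
        "print(f'Best candidate (score {best_score:.2f}): {best_transcript}')"
      ],
      "execution_count": null,
      "outputs": []
    }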
  ]
}
